{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/VincentStimper/normalizing-flows/blob/master/examples/glow_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image generation with Glow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we show how a flow can be trained to generate images with the `normflows` package. The flow is a class-conditional [Glow](https://arxiv.org/abs/1807.03039) model, which is based on the [multi-scale architecture](https://arxiv.org/abs/1605.08803). This Glow model is applied to the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get started, we have to install the `normflows` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install normflows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required packages\n",
    "import torch\n",
    "import torchvision as tv\n",
    "import numpy as np\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have imported the necessary packages, we create a flow model. Glow consists of `nf.flows.GlowBlock` layers, which are arranged in a `nf.MultiscaleFlow`, following the multi-scale architecture. The base distribution is a `nf.distributions.ClassCondDiagGaussian`, which is a diagonal Gaussian whose mean and standard deviation depend on the class label of the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Set up model\n",
    "\n",
    "# Define flows\n",
    "L = 3\n",
    "K = 16\n",
    "torch.manual_seed(0)\n",
    "\n",
    "input_shape = (3, 32, 32)\n",
    "n_dims = np.prod(input_shape)\n",
    "channels = 3\n",
    "hidden_channels = 256\n",
    "split_mode = 'channel'\n",
    "scale = True\n",
    "num_classes = 10\n",
    "\n",
    "# Set up flows, distributions and merge operations\n",
    "q0 = []\n",
    "merges = []\n",
    "flows = []\n",
    "for i in range(L):\n",
    "    flows_ = []\n",
    "    for j in range(K):\n",
    "        flows_ += [nf.flows.GlowBlock(channels * 2 ** (L + 1 - i), hidden_channels,\n",
    "                                     split_mode=split_mode, scale=scale)]\n",
    "    flows_ += [nf.flows.Squeeze()]\n",
    "    flows += [flows_]\n",
    "    if i > 0:\n",
    "        merges += [nf.flows.Merge()]\n",
    "        latent_shape = (input_shape[0] * 2 ** (L - i), input_shape[1] // 2 ** (L - i), \n",
    "                        input_shape[2] // 2 ** (L - i))\n",
    "    else:\n",
    "        latent_shape = (input_shape[0] * 2 ** (L + 1), input_shape[1] // 2 ** L, \n",
    "                        input_shape[2] // 2 ** L)\n",
    "    q0 += [nf.distributions.ClassCondDiagGaussian(latent_shape, num_classes)]\n",
    "\n",
    "\n",
    "# Construct flow model with the multiscale architecture\n",
    "model = nf.MultiscaleFlow(q0, flows, merges)"
   ]
  },
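  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the shape arithmetic in the cell above, with `L = 3` and `input_shape = (3, 32, 32)` the expressions evaluate as follows. For `i = 0`, the Glow blocks act on $3 \\cdot 2^4 = 48$ channels and the latent shape is $(48, 4, 4)$; for `i = 1`, they act on $24$ channels and the latent shape is $(12, 8, 8)$; for `i = 2`, they act on $12$ channels and the latent shape is $(6, 16, 16)$. The latent shapes add up to $768 + 768 + 1536 = 3072 = 3 \\cdot 32 \\cdot 32$ dimensions, which matches `n_dims`."
   ]
  },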
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move model to GPU if available\n",
    "enable_cuda = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `torchvision` we can download the CIFAR-10 dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare training data\n",
    "batch_size = 128\n",
    "\n",
    "transform = tv.transforms.Compose([tv.transforms.ToTensor(), nf.utils.Scale(255. / 256.), nf.utils.Jitter(1 / 256.)])\n",
    "train_data = tv.datasets.CIFAR10('datasets/', train=True,\n",
    "                                 download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,\n",
    "                                           drop_last=True)\n",
    "\n",
    "test_data = tv.datasets.CIFAR10('datasets/', train=False,\n",
    "                                download=True, transform=transform)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)\n",
    "\n",
    "train_iter = iter(train_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can train the model on the image data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 20000\n",
    "\n",
    "loss_hist = np.array([])\n",
    "\n",
    "optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3, weight_decay=1e-5)\n",
    "\n",
    "for i in tqdm(range(max_iter)):\n",
    "    try:\n",
    "        x, y = next(train_iter)\n",
    "    except StopIteration:\n",
    "        train_iter = iter(train_loader)\n",
    "        x, y = next(train_iter)\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.forward_kld(x.to(device), y.to(device))\n",
    "        \n",
    "    if ~(torch.isnan(loss) | torch.isinf(loss)):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    loss_hist = np.append(loss_hist, loss.detach().to('cpu').numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(loss_hist, label='loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To evaluate our model, we first draw samples from it. When sampling, we can specify the classes, so we draw ten samples from each class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Model samples\n",
    "num_sample = 10\n",
    "\n",
    "with torch.no_grad():\n",
    "    y = torch.arange(num_classes).repeat(num_sample).to(device)\n",
    "    x, _ = model.sample(y=y)\n",
    "    x_ = torch.clamp(x, 0, 1)\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    plt.imshow(np.transpose(tv.utils.make_grid(x_, nrow=num_classes).cpu().numpy(), (1, 2, 0)))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For quantitative evaluation, we can compute the bits per dimension of our model on the test set."
   ]
  },
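  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell converts the negative log-likelihood (NLL, in nats) returned by the model into bits per dimension. As a sketch, assuming the preprocessing above maps the 8-bit pixel values to the interval $[0, 1)$ (an assumption about how `nf.utils.Scale` and `nf.utils.Jitter` behave), rescaling back to $[0, 256)$ adds $\\log_2 256 = 8$ bits per dimension:\n",
    "\n",
    "$$\\text{bpd} = \\frac{\\text{NLL}}{D \\ln 2} + 8,$$\n",
    "\n",
    "where $D = 3 \\cdot 32 \\cdot 32 = 3072$ is the number of dimensions (`n_dims`)."
   ]
  },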
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get bits per dim\n",
    "n = 0\n",
    "bpd_cum = 0\n",
    "with torch.no_grad():\n",
    "    for x, y in iter(test_loader):\n",
    "        nll = model(x.to(device), y.to(device))\n",
    "        nll_np = nll.cpu().numpy() \n",
    "        bpd_cum += np.nansum(nll_np / np.log(2) / n_dims + 8)\n",
    "        n += len(x) - np.sum(np.isnan(nll_np))\n",
    "        \n",
    "    print('Bits per dim: ', bpd_cum / n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that competitive performance requires a much larger model than the one specified in this notebook, trained for 100 thousand to 1 million iterations."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
